{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tree species classification example\n",
    "This notebook gives an example of using a convolutional neural network to classify tree species in the Sierra Nevada forest."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we download the NEON data and label files from our dataset stored on Zenodo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import tqdm\n",
    "import argparse\n",
    "\n",
    "from wget import download\n",
    "\n",
    "from experiment.paths import *\n",
    "\n",
    "# make output directory if necessary\n",
    "if not os.path.exists('data'):\n",
    "    os.makedirs('data')\n",
    "\n",
    "files = [ 'Labels_Trimmed_Selective.CPG',\n",
    "          'Labels_Trimmed_Selective.dbf',\n",
    "          'Labels_Trimmed_Selective.prj',\n",
    "          'Labels_Trimmed_Selective.sbn',\n",
    "          'Labels_Trimmed_Selective.sbx',\n",
    "          'Labels_Trimmed_Selective.shp',\n",
    "          'Labels_Trimmed_Selective.shp.xml',\n",
    "          'Labels_Trimmed_Selective.shx',\n",
    "          'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif',\n",
    "          'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.aux.xml',\n",
    "          'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.enp',\n",
    "          'NEON_D17_TEAK_DP1_20170627_181333_reflectance.tif.ovr',\n",
    "          'D17_CHM_all.tfw',\n",
    "          'D17_CHM_all.tif',\n",
    "          'D17_CHM_all.tif.aux.xml',\n",
    "          'D17_CHM_all.tif.ovr',\n",
    "        ]\n",
    "\n",
    "for f in files:\n",
    "    if not os.path.exists('data/%s'%f):\n",
    "        print('downloading %s'%f)\n",
    "        download('https://zenodo.org/record/3468720/files/%s?download=1'%f,'data/%s'%f)\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we loads and co-register our data sources, including the hyperspectral image, the canopy height model, and the tree labels.  Then we build a dataset of patches and their corresponding labels and store it in a HDF5 file for easy use in Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 15668/15668 [05:17<00:00, 49.38it/s]\n",
      "100%|██████████| 1909/1909 [00:39<00:00, 48.41it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tqdm\n",
    "from experiment.paths import *\n",
    "import os\n",
    "\n",
    "from canopy.vector_utils import *\n",
    "from canopy.extract import *\n",
    "import h5py as h5\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Load the metadata from the image.\n",
    "with rasterio.open(image_uri) as src:\n",
    "    image_meta = src.meta.copy()\n",
    "\n",
    "os.makedirs('example',exist_ok=True)\n",
    "\n",
    "seed = 0\n",
    "\n",
    "# Load the shapefile and transform it to the hypersectral image's CRS.\n",
    "polygons, labels = load_and_transform_shapefile(labels_shp_uri,'SP',image_meta['crs'])\n",
    "\n",
    "# Cluster polygons for use in stratified sampling\n",
    "centroids = np.stack([np.mean(np.array(poly['coordinates'][0]),axis=0) for poly in polygons])\n",
    "cluster_ids = KMeans(10).fit_predict(centroids)\n",
    "rasterize_shapefile(polygons, cluster_ids, image_meta, 'example/clusters.tiff')\n",
    "stratify = cluster_ids\n",
    "\n",
    "# alternative: stratify by species label\n",
    "# stratify = labels\n",
    "\n",
    "# Split up polygons into train, val, test here\n",
    "train_inds, test_inds = train_test_split(range(len(polygons)),test_size=0.1,random_state=seed,stratify=stratify)\n",
    "\n",
    "# Save ids of train,val,test polygons\n",
    "with open('example/' + train_ids_uri,'w') as f:\n",
    "    f.writelines([\"%d\\n\"%ind for ind in train_inds])\n",
    "with open('example/' + test_ids_uri,'w') as f:\n",
    "    f.writelines([\"%d\\n\"%ind for ind in test_inds])\n",
    "\n",
    "# Separate out polygons\n",
    "train_polygons = [polygons[ind] for ind in train_inds]\n",
    "train_labels = [labels[ind] for ind in train_inds]\n",
    "test_polygons = [polygons[ind] for ind in test_inds]\n",
    "test_labels = [labels[ind] for ind in test_inds]\n",
    "\n",
    "# Rasterize the shapefile to a TIFF.  Using LZW compression, the resulting file is pretty small.\n",
    "train_labels_raster = rasterize_shapefile(train_polygons, train_labels, image_meta, 'example/' + train_labels_uri)\n",
    "test_labels_raster = rasterize_shapefile(test_polygons, test_labels, image_meta, 'example/' + test_labels_uri)\n",
    "\n",
    "# Extract patches and labels\n",
    "patch_radius = 7\n",
    "height_threshold = 5\n",
    "train_image_patches, train_patch_labels = extract_patches(image_uri,patch_radius,chm_uri,height_threshold,'example/' + train_labels_uri)\n",
    "test_image_patches, test_patch_labels = extract_patches(image_uri,patch_radius,chm_uri,height_threshold,'example/' + test_labels_uri)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we set up and train the convolutional neural network model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class weights:  [ 0.74829501  2.29405615  1.21758085  0.48317187  0.7970631  24.93668831\n",
      "  2.45540281  0.61169959]\n",
      "(15, 15, 32) int16\n",
      "() uint8\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_7 (InputLayer)         (None, 15, 15, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_16 (Conv2D)           (None, 13, 13, 32)        9248      \n",
      "_________________________________________________________________\n",
      "conv2d_17 (Conv2D)           (None, 11, 11, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_18 (Conv2D)           (None, 9, 9, 128)         73856     \n",
      "_________________________________________________________________\n",
      "conv2d_19 (Conv2D)           (None, 7, 7, 128)         147584    \n",
      "_________________________________________________________________\n",
      "conv2d_20 (Conv2D)           (None, 5, 5, 128)         147584    \n",
      "_________________________________________________________________\n",
      "conv2d_21 (Conv2D)           (None, 3, 3, 128)         147584    \n",
      "_________________________________________________________________\n",
      "conv2d_22 (Conv2D)           (None, 1, 1, 128)         147584    \n",
      "_________________________________________________________________\n",
      "conv2d_23 (Conv2D)           (None, 1, 1, 8)           1032      \n",
      "_________________________________________________________________\n",
      "flatten_2 (Flatten)          (None, 8)                 0         \n",
      "=================================================================\n",
      "Total params: 692,968\n",
      "Trainable params: 692,968\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "augmenting images: 100%|██████████| 122888/122888 [00:01<00:00, 73026.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 110599 samples, validate on 12289 samples\n",
      "Epoch 1/20\n",
      "110599/110599 [==============================] - 41s 366us/step - loss: 1.7037 - acc: 0.5477 - val_loss: 1.1374 - val_acc: 0.8593\n",
      "\n",
      "Epoch 00001: val_acc improved from -inf to 0.85931, saving model to example/weights.hdf5\n",
      "Epoch 2/20\n",
      "110599/110599 [==============================] - 44s 399us/step - loss: 0.9412 - acc: 0.9016 - val_loss: 0.9102 - val_acc: 0.9249\n",
      "\n",
      "Epoch 00002: val_acc improved from 0.85931 to 0.92489, saving model to example/weights.hdf5\n",
      "Epoch 3/20\n",
      "110599/110599 [==============================] - 44s 397us/step - loss: 0.8229 - acc: 0.9410 - val_loss: 0.8023 - val_acc: 0.9571\n",
      "\n",
      "Epoch 00003: val_acc improved from 0.92489 to 0.95712, saving model to example/weights.hdf5\n",
      "Epoch 4/20\n",
      "110599/110599 [==============================] - 44s 395us/step - loss: 0.7664 - acc: 0.9590 - val_loss: 0.7577 - val_acc: 0.9675\n",
      "\n",
      "Epoch 00004: val_acc improved from 0.95712 to 0.96745, saving model to example/weights.hdf5\n",
      "Epoch 5/20\n",
      "110599/110599 [==============================] - 44s 397us/step - loss: 0.7245 - acc: 0.9712 - val_loss: 0.7225 - val_acc: 0.9788\n",
      "\n",
      "Epoch 00005: val_acc improved from 0.96745 to 0.97876, saving model to example/weights.hdf5\n",
      "Epoch 6/20\n",
      "110599/110599 [==============================] - 44s 400us/step - loss: 0.6950 - acc: 0.9795 - val_loss: 0.6946 - val_acc: 0.9841\n",
      "\n",
      "Epoch 00006: val_acc improved from 0.97876 to 0.98413, saving model to example/weights.hdf5\n",
      "Epoch 7/20\n",
      "110599/110599 [==============================] - 45s 404us/step - loss: 0.6772 - acc: 0.9846 - val_loss: 0.6740 - val_acc: 0.9900\n",
      "\n",
      "Epoch 00007: val_acc improved from 0.98413 to 0.98999, saving model to example/weights.hdf5\n",
      "Epoch 8/20\n",
      "110599/110599 [==============================] - 45s 404us/step - loss: 0.6574 - acc: 0.9896 - val_loss: 0.6548 - val_acc: 0.9941\n",
      "\n",
      "Epoch 00008: val_acc improved from 0.98999 to 0.99406, saving model to example/weights.hdf5\n",
      "Epoch 9/20\n",
      "110599/110599 [==============================] - 45s 409us/step - loss: 0.6461 - acc: 0.9918 - val_loss: 0.6478 - val_acc: 0.9924\n",
      "\n",
      "Epoch 00009: val_acc did not improve from 0.99406\n",
      "Epoch 10/20\n",
      "110599/110599 [==============================] - 46s 415us/step - loss: 0.6434 - acc: 0.9918 - val_loss: 0.6347 - val_acc: 0.9934\n",
      "\n",
      "Epoch 00010: val_acc did not improve from 0.99406\n",
      "Epoch 11/20\n",
      "110599/110599 [==============================] - 43s 389us/step - loss: 0.6213 - acc: 0.9961 - val_loss: 0.6205 - val_acc: 0.9970\n",
      "\n",
      "Epoch 00011: val_acc improved from 0.99406 to 0.99699, saving model to example/weights.hdf5\n",
      "Epoch 12/20\n",
      "110599/110599 [==============================] - 43s 391us/step - loss: 0.6118 - acc: 0.9969 - val_loss: 0.6206 - val_acc: 0.9928\n",
      "\n",
      "Epoch 00012: val_acc did not improve from 0.99699\n",
      "Epoch 13/20\n",
      "110599/110599 [==============================] - 44s 394us/step - loss: 0.6026 - acc: 0.9976 - val_loss: 0.6020 - val_acc: 0.9972\n",
      "\n",
      "Epoch 00013: val_acc improved from 0.99699 to 0.99723, saving model to example/weights.hdf5\n",
      "Epoch 14/20\n",
      "110599/110599 [==============================] - 42s 384us/step - loss: 0.5938 - acc: 0.9982 - val_loss: 0.5922 - val_acc: 0.9980\n",
      "\n",
      "Epoch 00014: val_acc improved from 0.99723 to 0.99805, saving model to example/weights.hdf5\n",
      "Epoch 15/20\n",
      "110599/110599 [==============================] - 45s 407us/step - loss: 0.5853 - acc: 0.9986 - val_loss: 0.5822 - val_acc: 0.9988\n",
      "\n",
      "Epoch 00015: val_acc improved from 0.99805 to 0.99878, saving model to example/weights.hdf5\n",
      "Epoch 16/20\n",
      "110599/110599 [==============================] - 45s 410us/step - loss: 0.5771 - acc: 0.9989 - val_loss: 0.5746 - val_acc: 0.9989\n",
      "\n",
      "Epoch 00016: val_acc improved from 0.99878 to 0.99886, saving model to example/weights.hdf5\n",
      "Epoch 17/20\n",
      "110599/110599 [==============================] - 44s 399us/step - loss: 0.5692 - acc: 0.9990 - val_loss: 0.5686 - val_acc: 0.9985\n",
      "\n",
      "Epoch 00017: val_acc did not improve from 0.99886\n",
      "Epoch 18/20\n",
      "110599/110599 [==============================] - 44s 402us/step - loss: 0.5615 - acc: 0.9993 - val_loss: 0.5595 - val_acc: 0.9987\n",
      "\n",
      "Epoch 00018: val_acc did not improve from 0.99886\n",
      "Epoch 19/20\n",
      "110599/110599 [==============================] - 46s 411us/step - loss: 0.5538 - acc: 0.9993 - val_loss: 0.5559 - val_acc: 0.9981\n",
      "\n",
      "Epoch 00019: val_acc did not improve from 0.99886\n",
      "Epoch 20/20\n",
      "110599/110599 [==============================] - 43s 389us/step - loss: 0.5464 - acc: 0.9995 - val_loss: 0.5446 - val_acc: 0.9991\n",
      "\n",
      "Epoch 00020: val_acc improved from 0.99886 to 0.99910, saving model to example/weights.hdf5\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f3f1c6d8ba8>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import h5py as h5\n",
    "from tqdm import tqdm, trange\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
    "from tensorflow.keras.optimizers import SGD, Adam\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from joblib import dump, load\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from canopy.model import PatchClassifier\n",
    "from experiment.paths import *\n",
    "\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "K.set_session(sess)\n",
    "\n",
    "np.random.seed(0)\n",
    "tf.set_random_seed(0)\n",
    "\n",
    "out = 'example'\n",
    "lr = 0.0001\n",
    "epochs = 20\n",
    "\n",
    "x_all = train_image_patches\n",
    "y_all = train_patch_labels\n",
    "\n",
    "class_weights = compute_class_weight('balanced',range(8),y_all)\n",
    "print('class weights: ',class_weights)\n",
    "class_weight_dict = {}\n",
    "for i in range(8):\n",
    "    class_weight_dict[i] = class_weights[i]\n",
    "\n",
    "def estimate_pca():\n",
    "    x_samples = x_all[:,7,7]\n",
    "    pca = PCA(32,whiten=True)\n",
    "    pca.fit(x_samples)\n",
    "    return pca\n",
    "\n",
    "\"\"\"Normalize training data\"\"\"\n",
    "pca = estimate_pca()\n",
    "dump(pca,out + '/pca.joblib')\n",
    "\n",
    "x_shape = x_all.shape[1:]\n",
    "x_dtype = x_all.dtype\n",
    "y_shape = y_all.shape[1:]\n",
    "y_dtype = y_all.dtype\n",
    "x_shape = x_shape[:-1] + (pca.n_components_,)\n",
    "\n",
    "print(x_shape, x_dtype)\n",
    "print(y_shape, y_dtype)\n",
    "\n",
    "classifier = PatchClassifier(num_classes=8)\n",
    "model = classifier.get_patch_model(x_shape)\n",
    "\n",
    "print(model.summary())\n",
    "\n",
    "model.compile(optimizer=SGD(lr,momentum=0.9), loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "def apply_pca(x):\n",
    "    N,H,W,C = x.shape\n",
    "    x = np.reshape(x,(-1,C))\n",
    "    x = pca.transform(x)\n",
    "    x = np.reshape(x,(-1,H,W,x.shape[-1]))\n",
    "    return x\n",
    "\n",
    "checkpoint = ModelCheckpoint(filepath=out + '/' + weights_uri, monitor='val_acc', verbose=True, save_best_only=True, save_weights_only=True)\n",
    "reducelr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=10, verbose=1, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)\n",
    "\n",
    "x_all = apply_pca(x_all)\n",
    "\n",
    "def augment_images(x,y):\n",
    "    x_aug = []\n",
    "    y_aug = []\n",
    "    with tqdm(total=len(x)*8,desc='augmenting images') as pbar:\n",
    "        for rot in range(4):\n",
    "            for flip in range(2):\n",
    "                for patch,label in zip(x,y):\n",
    "                    patch = np.rot90(patch,rot)\n",
    "                    if flip:\n",
    "                        patch = np.flip(patch,axis=0)\n",
    "                        patch = np.flip(patch,axis=1)\n",
    "                    x_aug.append(patch)\n",
    "                    y_aug.append(label)\n",
    "                    pbar.update(1)\n",
    "    return np.stack(x_aug,axis=0), np.stack(y_aug,axis=0)\n",
    "\n",
    "x_all, y_all = augment_images(x_all,y_all)\n",
    "\n",
    "train_inds, val_inds = train_test_split(range(len(x_all)),test_size=0.1,random_state=0)\n",
    "x_train = np.stack([x_all[i] for i in train_inds],axis=0)\n",
    "y_train = np.stack([y_all[i] for i in train_inds],axis=0)\n",
    "x_val = np.stack([x_all[i] for i in val_inds],axis=0)\n",
    "y_val = np.stack([y_all[i] for i in val_inds],axis=0)\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "model.fit( x_train, y_train,\n",
    "           epochs=epochs,\n",
    "           batch_size=batch_size,\n",
    "           validation_data=(x_val,y_val),\n",
    "           verbose=1,\n",
    "           callbacks=[checkpoint,reducelr],\n",
    "           class_weight=class_weight_dict)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we run the trained model on the full image in tiles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/774 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metadata for image\n",
      "nodata:\n",
      "None\n",
      "\n",
      "transform:\n",
      "| 1.00, 0.00, 319344.00|\n",
      "| 0.00,-1.00, 4101691.00|\n",
      "| 0.00, 0.00, 1.00|\n",
      "\n",
      "width:\n",
      "1028\n",
      "\n",
      "count:\n",
      "426\n",
      "\n",
      "height:\n",
      "10948\n",
      "\n",
      "dtype:\n",
      "int16\n",
      "\n",
      "crs:\n",
      "+init=epsg:32611\n",
      "\n",
      "driver:\n",
      "GTiff\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|████████▉ | 688/774 [02:58<00:20,  4.23it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "from math import floor, ceil\n",
    "import tqdm\n",
    "from joblib import dump, load\n",
    "\n",
    "import rasterio\n",
    "from rasterio.windows import Window\n",
    "from rasterio.enums import Resampling\n",
    "from rasterio.vrt import WarpedVRT\n",
    "\n",
    "from canopy.model import PatchClassifier\n",
    "from experiment.paths import *\n",
    "\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "K.set_session(sess)\n",
    "\n",
    "pca = load(out + '/pca.joblib')\n",
    "\n",
    "# \"no data value\" for labels\n",
    "label_ndv = 255\n",
    "\n",
    "# radius of square patch (side of patch = 2*radius+1)\n",
    "patch_radius = 7\n",
    "\n",
    "# height threshold for CHM -- pixels at or below this height will be discarded\n",
    "height_threshold = 5\n",
    "\n",
    "# tile size for processing\n",
    "tile_size = 128\n",
    "\n",
    "# tile size with padding\n",
    "padded_tile_size = tile_size + 2*patch_radius\n",
    "\n",
    "# open the hyperspectral or RGB image\n",
    "image = rasterio.open(image_uri)\n",
    "image_meta = image.meta.copy()\n",
    "image_ndv = image.meta['nodata']\n",
    "image_width = image.meta['width']\n",
    "image_height = image.meta['height']\n",
    "image_channels = image.meta['count']\n",
    "\n",
    "# load model\n",
    "input_shape = (padded_tile_size,padded_tile_size,pca.n_components_)\n",
    "tree_classifier = PatchClassifier(num_classes=8)\n",
    "training_model = tree_classifier.get_patch_model(input_shape)\n",
    "training_model.load_weights(out + '/' + weights_uri)\n",
    "model = tree_classifier.get_convolutional_model(input_shape)\n",
    "\n",
    "# calculate number of tiles\n",
    "num_tiles_y = ceil(image_height / float(tile_size))\n",
    "num_tiles_x = ceil(image_width / float(tile_size))\n",
    "\n",
    "print('Metadata for image')\n",
    "for key in image_meta.keys():\n",
    "    print('%s:'%key)\n",
    "    print(image_meta[key])\n",
    "    print()\n",
    "\n",
    "# create predicted label raster\n",
    "predict_meta = image_meta.copy()\n",
    "predict_meta['dtype'] = 'uint8'\n",
    "predict_meta['nodata'] = label_ndv\n",
    "predict_meta['count'] = 1\n",
    "predict = rasterio.open(out + '/' + predict_uri, 'w', compress='lzw', **predict_meta)\n",
    "\n",
    "# open the CHM\n",
    "chm = rasterio.open(chm_uri)\n",
    "chm_vrt = WarpedVRT(chm, crs=image.meta['crs'], transform=image.meta['transform'], width=image.meta['width'], height=image.meta['height'],\n",
    "                   resampling=Resampling.bilinear)\n",
    "\n",
    "# dilation kernel\n",
    "kernel = np.ones((patch_radius*2+1,patch_radius*2+1),dtype=np.uint8)\n",
    "\n",
    "def apply_pca(x):\n",
    "    N,H,W,C = x.shape\n",
    "    x = np.reshape(x,(-1,C))\n",
    "    x = pca.transform(x)\n",
    "    x = np.reshape(x,(-1,H,W,x.shape[-1]))\n",
    "    return x\n",
    "\n",
    "# go through all tiles of input image\n",
    "# run convolutional model on tile\n",
    "# write labels to output label raster\n",
    "with tqdm.tqdm(total=num_tiles_y*num_tiles_x) as pbar:\n",
    "    for y in range(patch_radius,image_height-patch_radius,tile_size):\n",
    "        for x in range(patch_radius,image_width-patch_radius,tile_size):\n",
    "            pbar.update(1)\n",
    "\n",
    "            window = Window(x-patch_radius,y-patch_radius,padded_tile_size,padded_tile_size)\n",
    "\n",
    "            # get tile from chm\n",
    "            chm_tile = chm_vrt.read(1,window=window)\n",
    "            if chm_tile.shape[0] != padded_tile_size or chm_tile.shape[1] != padded_tile_size:\n",
    "                pad = ((0,padded_tile_size-chm_tile.shape[0]),(0,padded_tile_size-chm_tile.shape[1]))\n",
    "                chm_tile = np.pad(chm_tile,pad,mode='constant',constant_values=0)\n",
    "          \n",
    "            chm_tile = np.expand_dims(chm_tile,axis=0)\n",
    "            chm_bad = chm_tile <= height_threshold\n",
    "\n",
    "            # get tile from image\n",
    "            image_tile = image.read(window=window)\n",
    "            image_pad_y = padded_tile_size-image_tile.shape[1]\n",
    "            image_pad_x = padded_tile_size-image_tile.shape[2]\n",
    "            output_window = Window(x,y,tile_size-image_pad_x,tile_size-image_pad_y)\n",
    "            if image_tile.shape[1] != padded_tile_size or image_tile.shape[2] != padded_tile_size:\n",
    "                pad = ((0,0),(0,image_pad_y),(0,image_pad_x))\n",
    "                image_tile = np.pad(image_tile,pad,mode='constant',constant_values=-1)\n",
    "\n",
    "            # re-order image tile to have height,width,channels\n",
    "            image_tile = np.transpose(image_tile,axes=[1,2,0])\n",
    "\n",
    "            # add batch axis\n",
    "            image_tile = np.expand_dims(image_tile,axis=0)\n",
    "            image_bad = np.any(image_tile < 0,axis=-1)\n",
    "\n",
    "            image_tile = image_tile.astype('float32')\n",
    "            image_tile = apply_pca(image_tile)\n",
    "            \n",
    "            # run tile through network\n",
    "            predict_tile = np.argmax(model.predict(image_tile),axis=-1).astype('uint8')\n",
    "\n",
    "            # dilate mask\n",
    "            image_bad = cv2.dilate(image_bad.astype('uint8'),kernel).astype('bool')\n",
    "\n",
    "            # set bad pixels to NDV\n",
    "            predict_tile[chm_bad[:,patch_radius:-patch_radius,patch_radius:-patch_radius]] = label_ndv\n",
    "            predict_tile[image_bad[:,patch_radius:-patch_radius,patch_radius:-patch_radius]] = label_ndv\n",
    "\n",
    "            # undo padding\n",
    "            if image_pad_y > 0:\n",
    "                predict_tile = predict_tile[:,:-image_pad_y,:]\n",
    "            if image_pad_x > 0:\n",
    "                predict_tile = predict_tile[:,:,:-image_pad_x]\n",
    "\n",
    "            # write to file\n",
    "            predict.write(predict_tile,window=output_window)\n",
    "\n",
    "image.close()\n",
    "chm.close()\n",
    "predict.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we run an analysis of the classification performance on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "classification report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.62      0.89      0.73         9\n",
      "          1       0.00      0.00      0.00         1\n",
      "          2       0.82      1.00      0.90         9\n",
      "          3       1.00      0.88      0.93        16\n",
      "          4       0.88      1.00      0.93         7\n",
      "          5       0.00      0.00      0.00         2\n",
      "          6       0.56      0.71      0.63         7\n",
      "          7       1.00      0.67      0.80        21\n",
      "\n",
      "avg / total       0.83      0.79      0.80        72\n",
      "\n",
      "confusion matrix:\n",
      "[[ 8  1  0  0  0  0  0  0]\n",
      " [ 1  0  0  0  0  0  0  0]\n",
      " [ 0  0  9  0  0  0  0  0]\n",
      " [ 1  0  0 14  1  0  0  0]\n",
      " [ 0  0  0  0  7  0  0  0]\n",
      " [ 0  0  0  0  0  0  2  0]\n",
      " [ 0  0  0  0  0  2  5  0]\n",
      " [ 3  0  2  0  0  0  2 14]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import rasterio\n",
    "from rasterio.windows import Window\n",
    "from rasterio.enums import Resampling\n",
    "from rasterio.vrt import WarpedVRT\n",
    "from rasterio.mask import mask\n",
    "\n",
    "from shapely.geometry import Polygon\n",
    "from shapely.geometry import Point\n",
    "from shapely.geometry import mapping\n",
    "\n",
    "import tqdm\n",
    "\n",
    "from math import floor, ceil\n",
    "\n",
    "from experiment.paths import *\n",
    "\n",
    "from canopy.vector_utils import *\n",
    "from canopy.extract import *\n",
    "\n",
    "import sklearn.metrics\n",
    "from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score\n",
    "\n",
    "train_inds = np.loadtxt(out + '/' + train_ids_uri,dtype='int32')\n",
    "test_inds = np.loadtxt(out + '/' + test_ids_uri,dtype='int32')\n",
    "\n",
    "# Load the metadata from the image.\n",
    "with rasterio.open(image_uri) as src:\n",
    "    image_meta = src.meta.copy()\n",
    "\n",
    "# Load the shapefile and transform it to the hypersectral image's CRS.\n",
    "polygons, labels = load_and_transform_shapefile(labels_shp_uri,'SP',image_meta['crs'])\n",
    "\n",
    "train_labels = [labels[ind] for ind in train_inds]\n",
    "test_labels = [labels[ind] for ind in test_inds]\n",
    "\n",
    "# open predicted label raster\n",
    "predict = rasterio.open(out + '/' + predict_uri)\n",
    "predict_raster = predict.read(1)\n",
    "ndv = predict.meta['nodata']\n",
    "\n",
    "def get_predictions(inds):\n",
    "    preds = []\n",
    "    for ind in inds:\n",
    "        poly = [mapping(Polygon(polygons[ind]['coordinates'][0]))]\n",
    "        out_image, out_transform = mask(predict, poly, crop=False)\n",
    "        out_image = out_image[0]\n",
    "        \n",
    "        label = labels[ind]\n",
    "\n",
    "        rows, cols = np.where(out_image != ndv)\n",
    "        predict_labels = []\n",
    "        for row, col in zip(rows,cols):\n",
    "            predict_labels.append(predict_raster[row,col])\n",
    "        predict_labels = np.array(predict_labels)\n",
    "        \n",
    "        hist = [np.count_nonzero(predict_labels==i) for i in range(8)]\n",
    "        majority_label = np.argmax(hist)\n",
    "        preds.append(majority_label)\n",
    "    return preds\n",
    "\n",
    "def calculate_confusion_matrix(labels,preds):\n",
    "    mat = np.zeros((8,8),dtype='int32')\n",
    "    for label,pred in zip(labels,preds):\n",
    "        mat[label,pred] += 1\n",
    "    return mat\n",
    "\n",
    "def calculate_fscore(labels,preds):\n",
    "    return sklearn.metrics.f1_score(labels,preds,average='micro')\n",
    "\n",
    "test_preds = get_predictions(test_inds)\n",
    " \n",
    "report = classification_report(test_labels, test_preds)\n",
    "mat = confusion_matrix(test_labels,test_preds)\n",
    "print('classification report:')\n",
    "print(report)\n",
    "print('confusion matrix:')\n",
    "print(mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
